import argparse
import os

from omegaconf import OmegaConf, DictConfig


def load_config(print_config=True, argv=None):
    """Parse command-line arguments and build the experiment config.

    The file given by ``--config`` is loaded, then every file reachable through
    a chain of ``base_config`` keys is merged *underneath* it, so keys in a
    derived config override the same keys in its bases. ``--opts KEY VALUE ...``
    overrides are applied after that. Interpolations are resolved only at the
    very end, so values such as ``${some_key}`` see the overridden settings.

    Args:
        print_config: If True, print the final config as YAML.
        argv: Argument list to parse instead of ``sys.argv[1:]`` (None keeps
            the default argparse behaviour). Useful for programmatic calls.

    Returns:
        The merged and resolved config.

    Raises:
        SystemExit: Raised by argparse for invalid arguments, including an odd
            number of ``--opts`` tokens.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str,
                        default='configs/imagenetr-ti2i/imagenetr-ti2i-o.yaml',
                        help="Config file path")
    parser.add_argument(
        "--opts",
        help="other configurations",
        default=None,
        nargs='+',)
    args = parser.parse_args(argv)
    config = OmegaConf.load(args.config)

    # Recursively merge base configs. Each base is merged underneath the
    # accumulated config. Every visited file is remembered, so any cycle
    # (A -> B -> A), not only a direct self-reference, stops the walk instead
    # of looping forever. Relative paths are opened relative to the current
    # working directory, exactly as OmegaConf.load receives them.
    visited = {os.path.abspath(args.config)}
    cur_config = config
    while cur_config.get("base_config"):
        base_path = cur_config.base_config
        abs_base_path = os.path.abspath(base_path)
        if abs_base_path in visited:
            break
        visited.add(abs_base_path)
        base_config = OmegaConf.load(base_path)
        config = OmegaConf.merge(base_config, config)
        cur_config = base_config

    if args.opts is not None:
        if len(args.opts) % 2 != 0:
            parser.error("--opts expects KEY VALUE pairs, "
                         "but an odd number of tokens was given")
        # from_dotlist parses each value as YAML, so "0.1" becomes a float,
        # "true" a bool, "[1,2]" a list, instead of always being a raw string.
        dotlist = [f"{k}={v}" for k, v in zip(args.opts[::2], args.opts[1::2])]
        config = OmegaConf.merge(config, OmegaConf.from_dotlist(dotlist))

    # Resolve only after overrides so interpolations pick up overridden values.
    OmegaConf.resolve(config)

    if print_config:
        print("[INFO] loaded config:")
        print(OmegaConf.to_yaml(config))

    return config



def save_config(config: DictConfig, path, gene=False, inv=False):
    """Write ``config`` to ``<path>/config.yaml``, optionally dropping a section.

    The config is first passed through ``OmegaConf.create``, which builds a
    separate config object. Sections are removed from that object, not from
    the caller's config.

    Args:
        config: Config to save (a DictConfig or anything OmegaConf.create accepts).
        path: Output directory. It is created if it does not exist.
        gene: If True, omit the top-level ``inversion`` section.
        inv: If True, omit the top-level ``generation`` section.
    """
    os.makedirs(path, exist_ok=True)
    config = OmegaConf.create(config)
    # Pop with a default: a config that already lacks the section still gets
    # saved, rather than failing with a key error.
    if gene:
        config.pop("inversion", None)
    if inv:
        config.pop("generation", None)
    OmegaConf.save(config, os.path.join(path, "config.yaml"))
